import fire
from transformers import AutoModelForCausalLM
from peft import PeftModel
from TSPA import TSPA

def main(
    base_model_id: str,
    adapter_ids: list[str],
    weights: list[float],
    combination_type: str = "linear",
    output_path: str = "merged_models",
    device: str = "cuda"
):
    """Align PEFT adapters with TSPA, merge them, and save the merged adapter.

    Intended to be invoked through ``fire`` (see the ``__main__`` guard), so
    every parameter is also a command-line flag.

    Args:
        base_model_id: Hub id or local path of the base causal LM.
        adapter_ids: Hub ids or local paths of the adapters to merge. The last
            path component of each id becomes its adapter name, so those
            components must be unique and non-empty. A single string is
            accepted and treated as a one-element list.
        weights: One merge weight per adapter, in the same order as
            ``adapter_ids``. A single number is accepted and treated as a
            one-element list.
        combination_type: Forwarded to ``PeftModel.add_weighted_adapter``.
        output_path: Root directory; results are written under
            ``<output_path>/<base model name>/<adapter names joined by "_">``.
        device: Forwarded to ``TSPA``. The base model itself is loaded with
            ``device_map="auto"`` and does not use this value.

    Raises:
        ValueError: if no adapters are given, if the number of weights does
            not match the number of adapters, or if the derived adapter names
            are empty or not unique.
    """
    # Fire hands a single CLI value over as a plain string/number rather than
    # a list; iterating a bare string would treat each character as an id.
    if isinstance(adapter_ids, str):
        adapter_ids = [adapter_ids]
    if isinstance(weights, (int, float)):
        weights = [weights]
    adapter_ids = list(adapter_ids)
    weights = list(weights)

    # Validate before the (expensive) model load so bad input fails fast.
    if not adapter_ids:
        raise ValueError("adapter_ids must contain at least one adapter id")
    if len(weights) != len(adapter_ids):
        raise ValueError(
            f"Expected one weight per adapter: got {len(weights)} weight(s) "
            f"for {len(adapter_ids)} adapter(s)"
        )

    # Strip trailing slashes so a local path such as "ckpts/lora_a/" yields
    # "lora_a" rather than an empty name.
    adapter_names = [adapter_id.rstrip("/").split("/")[-1] for adapter_id in adapter_ids]
    if not all(adapter_names):
        raise ValueError(f"Could not derive an adapter name from {adapter_ids!r}")
    duplicates = sorted({name for name in adapter_names if adapter_names.count(name) > 1})
    if duplicates:
        # Loading two adapters under the same name would make the second
        # overwrite the first, silently merging one adapter with itself.
        raise ValueError(
            f"Adapter names must be unique, but {duplicates} occur more than once "
            f"(derived from the last path component of {adapter_ids!r})"
        )

    # (A) Load base model. device_map="auto" lets transformers choose the
    # placement; the `device` argument is not used here.
    model = AutoModelForCausalLM.from_pretrained(base_model_id, device_map="auto")

    # (B) Wrap the base model with the first adapter, then attach the rest
    # under their own names so each can be addressed individually.
    peft_model = PeftModel.from_pretrained(
        model,
        adapter_ids[0],
        adapter_name=adapter_names[0],
    )
    for adapter_id, adapter_name in zip(adapter_ids[1:], adapter_names[1:]):
        peft_model.load_adapter(adapter_id, adapter_name=adapter_name)

    # (C) Run TSPA over the loaded adapters.
    # NOTE(review): the return value is unused, so compute_aligned_adapter()
    # presumably rewrites the adapters inside peft_model in place - confirm.
    tspa = TSPA(
        peft_model=peft_model,
        adapter_names=adapter_names,
        weights=weights,
        device=device
    )
    tspa.compute_aligned_adapter()

    # (D) Merge adapters. The merged name encodes the combination type and the
    # weights with decimal points removed (e.g. 0.5 -> "05").
    weight_tag = "_".join(str(w).replace(".", "") for w in weights)
    merged_name = f"TSPA_{combination_type}_w_{weight_tag}"
    peft_model.add_weighted_adapter(adapter_names, weights, merged_name, combination_type=combination_type)

    # (E) Keep only the merged adapter so save_pretrained writes just that one.
    peft_model.set_adapter(merged_name)
    for adapter_name in adapter_names:
        peft_model.delete_adapter(adapter_name)

    base_model_name = base_model_id.rstrip("/").split("/")[-1]
    output_path = f"{output_path}/{base_model_name}/{'_'.join(adapter_names)}"
    peft_model.save_pretrained(output_path)
    # PEFT stores a non-"default" adapter in a sub-directory named after it.
    print(f"Aligned model saved to {output_path}/{merged_name}")

if __name__ == "__main__":
    # Expose main()'s parameters as command-line flags via Python Fire.
    fire.Fire(component=main)